{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "# Train a Semantic Segmentation Model using Segmentation-Models-PyTorch\n",
    "\n",
    "[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/geoai/blob/main/docs/examples/train_segmentation_model.ipynb)\n",
    "\n",
    "This notebook demonstrates how to train semantic segmentation models for object detection (e.g., building detection) using the [segmentation-models-pytorch](https://smp.readthedocs.io) library. Unlike instance segmentation with Mask R-CNN, this approach treats the task as pixel-level binary classification.\n",
    "\n",
    "## Install packages\n",
    "To use the new functionality, ensure the required packages are installed.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install geoai-py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import geoai"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Download sample data\n",
    "\n",
    "We'll use the same dataset as the Mask R-CNN example for consistency."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_raster_url = (\n",
    "    \"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_rgb_train.tif\"\n",
    ")\n",
    "train_vector_url = \"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_train_buildings.geojson\"\n",
    "test_raster_url = (\n",
    "    \"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_test.tif\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_raster_path = geoai.download_file(train_raster_url)\n",
    "train_vector_path = geoai.download_file(train_vector_url)\n",
    "test_raster_path = geoai.download_file(test_raster_url)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Visualize sample data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.get_raster_info(train_raster_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.view_vector_interactive(train_vector_path, tiles=train_raster_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.view_raster(test_raster_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Create training data\n",
    "\n",
    "We'll create the same training tiles as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_folder = \"buildings\"\n",
    "tiles = geoai.export_geotiff_tiles(\n",
    "    in_raster=train_raster_path,\n",
    "    out_folder=out_folder,\n",
    "    in_class_data=train_vector_path,\n",
    "    tile_size=512,\n",
    "    stride=256,\n",
    "    buffer_radius=0,\n",
    ")"
   ]
  },
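  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `tile_size=512` and `stride=256`, each tile starts 256 pixels after the previous one, so most pixels appear in more than one training tile. The cell below is only a sketch of that sliding-window arithmetic: `tile_origins` is a hypothetical helper, the raster size is made up, and skipping partial tiles at the raster edge is an assumption rather than a description of how `export_geotiff_tiles` treats borders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tile_origins(width, height, tile_size=512, stride=256):\n",
    "    \"\"\"Return (col, row) upper-left corners of full tiles in a sliding window.\n",
    "\n",
    "    Hypothetical helper for illustration; partial edge tiles are skipped.\n",
    "    \"\"\"\n",
    "    cols = range(0, width - tile_size + 1, stride)\n",
    "    rows = range(0, height - tile_size + 1, stride)\n",
    "    return [(c, r) for r in rows for c in cols]\n",
    "\n",
    "\n",
    "# Made-up raster size, not the size of the sample NAIP image\n",
    "origins = tile_origins(width=2048, height=1024)\n",
    "print(len(origins), \"tiles; first three:\", origins[:3])"
   ]
  },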
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Train semantic segmentation model\n",
    "\n",
    "Now we'll train a semantic segmentation model using the new `train_segmentation_model` function. This function supports various architectures from `segmentation-models-pytorch`:\n",
    "\n",
    "- **Architectures**: `unet`, `unetplusplus` `deeplabv3`, `deeplabv3plus`, `fpn`, `pspnet`, `linknet`, `manet`\n",
    "- **Encoders**: `resnet34`, `resnet50`, `efficientnet-b0`, `mobilenet_v2`, etc.\n",
    "\n",
    "For more details, please refer to the [segmentation-models-pytorch documentation](https://smp.readthedocs.io/en/latest/models.html).\n",
    "\n",
    "### Example 1: U-Net with ResNet34 encoder\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train U-Net model\n",
    "geoai.train_segmentation_model(\n",
    "    images_dir=f\"{out_folder}/images\",\n",
    "    labels_dir=f\"{out_folder}/labels\",\n",
    "    output_dir=f\"{out_folder}/unet_models\",\n",
    "    architecture=\"unet\",\n",
    "    encoder_name=\"resnet34\",\n",
    "    encoder_weights=\"imagenet\",\n",
    "    num_channels=3,\n",
    "    num_classes=2,  # background and building\n",
    "    batch_size=8,\n",
    "    num_epochs=5,\n",
    "    learning_rate=0.001,\n",
    "    val_split=0.2,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "### Example 2: SegFormer with resnet152 encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.train_segmentation_model(\n",
    "    images_dir=f\"{out_folder}/images\",\n",
    "    labels_dir=f\"{out_folder}/labels\",\n",
    "    output_dir=f\"{out_folder}/segformer_models\",\n",
    "    architecture=\"segformer\",\n",
    "    encoder_name=\"resnet152\",\n",
    "    encoder_weights=\"imagenet\",\n",
    "    num_channels=3,\n",
    "    num_classes=2,\n",
    "    batch_size=6,  # Smaller batch size for more complex model\n",
    "    num_epochs=5,\n",
    "    learning_rate=0.0005,\n",
    "    val_split=0.2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Run inference\n",
    "\n",
    "Now we'll use the trained model to make predictions on the test image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define paths\n",
    "masks_path = \"naip_test_semantic_prediction.tif\"\n",
    "model_path = f\"{out_folder}/unet_models/best_model.pth\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run semantic segmentation inference\n",
    "geoai.semantic_segmentation(\n",
    "    input_path=test_raster_path,\n",
    "    output_path=masks_path,\n",
    "    model_path=model_path,\n",
    "    architecture=\"unet\",\n",
    "    encoder_name=\"resnet34\",\n",
    "    num_channels=3,\n",
    "    num_classes=2,\n",
    "    window_size=512,\n",
    "    overlap=256,\n",
    "    batch_size=4,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Output probability map (optional)\n",
    "\n",
    "You can also output the probability map by providing the `probability_path` parameter. This will save a multi-band raster where each band represents the probability for each class (0-1 range)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run inference with probability output\n",
    "probability_path = \"naip_test_probability_map.tif\"\n",
    "\n",
    "geoai.semantic_segmentation(\n",
    "    input_path=test_raster_path,\n",
    "    output_path=masks_path,\n",
    "    model_path=model_path,\n",
    "    architecture=\"unet\",\n",
    "    encoder_name=\"resnet34\",\n",
    "    num_channels=3,\n",
    "    num_classes=2,\n",
    "    window_size=512,\n",
    "    overlap=256,\n",
    "    batch_size=4,\n",
    "    probability_path=probability_path,  # Output probability map\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize probability map for building class (band 2)\n",
    "geoai.view_raster(\n",
    "    probability_path, indexes=[2], basemap=test_raster_path, backend=\"ipyleaflet\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also control the classification threshold for binary segmentation. By default, argmax is used, but you can specify a custom threshold:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run inference with custom probability threshold\n",
    "masks_path_threshold = \"naip_test_semantic_prediction_threshold2.tif\"\n",
    "\n",
    "geoai.semantic_segmentation(\n",
    "    input_path=test_raster_path,\n",
    "    output_path=masks_path_threshold,\n",
    "    model_path=model_path,\n",
    "    architecture=\"unet\",\n",
    "    encoder_name=\"resnet34\",\n",
    "    num_channels=3,\n",
    "    num_classes=2,\n",
    "    window_size=512,\n",
    "    overlap=256,\n",
    "    batch_size=4,\n",
    "    probability_threshold=0.3,  # Only classify as building if probability >= 0.7\n",
    ")"
   ]
  },
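  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the difference between the two decision rules concrete, the cell below applies both to a made-up 2 × 2 probability array with plain NumPy. It illustrates argmax versus thresholding in general, assuming the second class (band 2, index 1) is the building class; it is not the internal implementation of `geoai.semantic_segmentation`. With two classes, argmax is equivalent to a 0.5 cutoff, so a threshold below 0.5 labels more pixels as building."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up per-class probabilities with shape (classes, rows, cols)\n",
    "probs = np.array(\n",
    "    [\n",
    "        [[0.9, 0.6], [0.45, 0.2]],  # background\n",
    "        [[0.1, 0.4], [0.55, 0.8]],  # building\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Argmax: each pixel takes the class with the highest probability\n",
    "mask_argmax = probs.argmax(axis=0)\n",
    "\n",
    "# Threshold: a pixel is building when its building probability reaches 0.3\n",
    "mask_threshold = (probs[1] >= 0.3).astype(np.uint8)\n",
    "\n",
    "print(mask_argmax)\n",
    "print(mask_threshold)"
   ]
  },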
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Vectorize masks\n",
    "\n",
    "Convert the predicted mask to vector format for better visualization and analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_vector_path = \"naip_test_semantic_prediction.geojson\"\n",
    "gdf = geoai.orthogonalize(masks_path, output_vector_path, epsilon=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add geometric properties"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gdf_props = geoai.add_geometric_properties(gdf, area_unit=\"m2\", length_unit=\"m\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Visualize results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.view_raster(masks_path, nodata=0, basemap=test_raster_path, backend=\"ipyleaflet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.view_vector_interactive(gdf_props, column=\"area_m2\", tiles=test_raster_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gdf_filtered = gdf_props[(gdf_props[\"area_m2\"] > 50)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.view_vector_interactive(gdf_filtered, column=\"area_m2\", tiles=test_raster_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.create_split_map(\n",
    "    left_layer=gdf_filtered,\n",
    "    right_layer=test_raster_path,\n",
    "    left_args={\"style\": {\"color\": \"red\", \"fillOpacity\": 0.2}},\n",
    "    basemap=test_raster_path,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Model Performance Analysis\n",
    "\n",
    "Let's examine the training curves and model performance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.plot_performance_metrics(\n",
    "    history_path=f\"{out_folder}/unet_models/training_history.pth\",\n",
    "    figsize=(15, 5),\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://github.com/user-attachments/assets/9355446f-f9ba-4818-aedb-4bb5dee56813)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance Metrics\n",
    "\n",
    "**IoU (Intersection over Union)** and **Dice score** are both popular metrics used to evaluate the similarity between two binary masks—often in image segmentation tasks. While they are related, they are not the same.\n",
    "\n",
    "---\n",
    "\n",
    "### 🔸 **Definitions**\n",
    "\n",
    "#### **IoU (Jaccard Index)**\n",
    "\n",
    "$$\n",
    "\\text{IoU} = \\frac{|A \\cap B|}{|A \\cup B|}\n",
    "$$\n",
    "\n",
    "* Measures the overlap between predicted region $A$ and ground truth region $B$ relative to their union.\n",
    "* Ranges from 0 (no overlap) to 1 (perfect overlap).\n",
    "\n",
    "#### **Dice Score (F1 Score for Sets)**\n",
    "\n",
    "$$\n",
    "\\text{Dice} = \\frac{2|A \\cap B|}{|A| + |B|}\n",
    "$$\n",
    "\n",
    "* Measures the overlap between $A$ and $B$, but gives more weight to the intersection.\n",
    "* Also ranges from 0 to 1.\n",
    "\n",
    "---\n",
    "\n",
    "### 🔸 **Key Differences**\n",
    "\n",
    "| Metric   | Formula                     | Penalizes                      | Sensitivity                      |\n",
    "| -------- | --------------------------- | ------------------------------ | -------------------------------- |\n",
    "| **IoU**  | $\\frac{TP}{TP + FP + FN}$   | FP and FN equally              | Less sensitive to small objects  |\n",
    "| **Dice** | $\\frac{2TP}{2TP + FP + FN}$ | Less harsh on small mismatches | More sensitive to small overlaps |\n",
    "\n",
    "> TP: True Positive, FP: False Positive, FN: False Negative\n",
    "\n",
    "---\n",
    "\n",
    "### 🔸 **Relationship**\n",
    "\n",
    "Dice and IoU are mathematically related:\n",
    "\n",
    "$$\n",
    "\\text{Dice} = \\frac{2 \\cdot \\text{IoU}}{1 + \\text{IoU}} \\quad \\text{or} \\quad \\text{IoU} = \\frac{\\text{Dice}}{2 - \\text{Dice}}\n",
    "$$\n",
    "\n",
    "---\n",
    "\n",
    "### 🔸 **When to Use What**\n",
    "\n",
    "* **IoU**: Common in object detection and semantic segmentation benchmarks (e.g., COCO, Pascal VOC).\n",
    "* **Dice**: Preferred in medical imaging and when class imbalance is an issue, due to its sensitivity to small regions."
   ]
  }
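  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The definitions above can be checked numerically. The cell below is a minimal sketch that computes both metrics for two made-up binary masks with NumPy. `iou_score` and `dice_score` are hypothetical helper names, not functions from `geoai` or `segmentation-models-pytorch`, and the sketch assumes at least one of the masks is non-empty."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def iou_score(pred, truth):\n",
    "    \"\"\"Intersection over union of two boolean masks.\"\"\"\n",
    "    intersection = np.logical_and(pred, truth).sum()\n",
    "    union = np.logical_or(pred, truth).sum()\n",
    "    return intersection / union\n",
    "\n",
    "\n",
    "def dice_score(pred, truth):\n",
    "    \"\"\"Dice coefficient of two boolean masks.\"\"\"\n",
    "    intersection = np.logical_and(pred, truth).sum()\n",
    "    return 2 * intersection / (pred.sum() + truth.sum())\n",
    "\n",
    "\n",
    "# Made-up 1-D masks: 2 pixels overlap, 1 false positive, 1 false negative\n",
    "pred = np.array([1, 1, 1, 0, 0], dtype=bool)\n",
    "truth = np.array([0, 1, 1, 1, 0], dtype=bool)\n",
    "\n",
    "iou = iou_score(pred, truth)\n",
    "dice = dice_score(pred, truth)\n",
    "print(f\"IoU = {iou:.3f}, Dice = {dice:.3f}\")\n",
    "\n",
    "# The two values satisfy Dice = 2 * IoU / (1 + IoU)\n",
    "assert np.isclose(dice, 2 * iou / (1 + iou))"
   ]
  }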
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "geo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
